import logging
import pickle

import cv2
import numpy as np
from flask import jsonify, request
from werkzeug.exceptions import BadRequest

from app import facenet, anti_spoofing_two_stream_vit
from app.model.models import User, Photo, FaceModel
from app.recognition import compare
from app.utils import file_utils
from . import bp
from app.utils import retinex

LOG = logging.getLogger(__name__)


def _placeholder_user(username):
    """Build a ``User`` with id 0 that stands in for "no real user".

    :param username: label sent to the client in the ``username`` field.
    :return: a ``User`` populated via ``from_dict``; it is not loaded from the database.
    """
    user = User()
    user.from_dict({'id': 0, 'username': username, 'cell_phone_number': ''})
    return user


def _is_spoof(photo):
    """Run the two-stream anti-spoofing model on ``photo``.

    :param photo: the raw ``photo`` value taken from the request payload.
    :return: the result of comparing the model prediction with 1, which the
        endpoint treats as "not a live face".
    """
    image = file_utils.resize_blob_to(blob=photo, size=224)
    # Drop the colour information but keep three channels for the MSRCR call.
    # NOTE(review): presumably MSRCR expects a 3-channel image — confirm.
    image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    image_gray = cv2.cvtColor(image_gray, cv2.COLOR_GRAY2BGR)
    image_msr = retinex.MSRCR(image_gray, sigma_list=[15, 80, 250], G=5.0, b=25.0, alpha=125.0,
                              beta=46.0, low_clip=0.01, high_clip=0.99)
    # The model receives the MSRCR output as a single-channel image.
    image_msr = cv2.cvtColor(image_msr, cv2.COLOR_BGR2GRAY)
    y_pred = anti_spoofing_two_stream_vit.get_y_prediction(x=image, msr=image_msr)
    return y_pred == 1


@bp.route('/face/recognition', methods=['POST'])
def face_recognition():
    """Identify the user shown in the posted photo.

    Expects a JSON object with a required ``photo`` entry and an optional
    ``turn_on_anti_spoofing`` flag (treated as false when absent).

    :return: JSON response with ``success``, ``data`` (the user as a dict) and
        ``status_code``. ``data`` is a placeholder user with id 0 when the
        photo is judged not to be a live face ('非活体') or when no user can
        be determined ('无法识别').
    :raises BadRequest: if the payload is not a JSON object containing ``photo``.
    """
    data = request.json or {}
    # A JSON list or string would pass a bare ``in`` test and then fail on indexing.
    if not isinstance(data, dict) or 'photo' not in data:
        raise BadRequest('Must include photo.')
    photo = data['photo']
    turn_on_anti_spoofing = data.get('turn_on_anti_spoofing', False)

    # face anti-spoofing
    if turn_on_anti_spoofing and _is_spoof(photo):
        return jsonify(success=True, data=_placeholder_user('非活体').to_dict(), status_code=200)

    # face recognition
    image = file_utils.resize_blob_to(photo)
    embedding = facenet.get_embeddings(image)
    photos = Photo.query.all()

    user = None
    # With no enrolled photos there is nothing to compare against, so skip the
    # comparer instead of handing it empty arrays.
    if photos:
        # NOTE(security): pickle.loads executes arbitrary code on crafted input;
        # this is only safe while the stored embeddings come from a trusted writer.
        x = np.array([pickle.loads(p.embedding) for p in photos])
        y = np.array([p.user_id for p in photos])
        target = compare.face_net_compare(embedding, x, y)
        # A target of 0 is treated as "no match".
        if target != 0:
            # first() returns None when the matched id has no User row.
            user = User.query.filter_by(id=int(target)).first()

    if user is None:
        user = _placeholder_user('无法识别')
    return jsonify(success=True, data=user.to_dict(), status_code=200)


# The rule starts with '/' like the other route in this module; a rule without
# a leading slash is rejected when the blueprint has no url_prefix.
@bp.route('/face/models/<int:task_id>', methods=['GET'])
def get_face_models(task_id):
    """List the face models registered for a task as label/value pairs.

    :param task_id: integer task id taken from the URL.
    :return: JSON response with ``success``, ``status_code`` and ``data``, a
        list of ``{'label': model_name, 'value': id}`` dicts (empty when no
        model matches).
    :raises BadRequest: if ``task_id`` is falsy (e.g. 0).
    """
    if not task_id:
        raise BadRequest()
    face_models = FaceModel.query.filter_by(task_id=task_id).all()
    result = [{'label': face_model.model_name, 'value': face_model.id} for face_model in face_models]
    return jsonify(success=True, data=result, status_code=200)
